#include "stdafx.h"
#include "RSA_System.h"

//#define CHINESE
//#define SCREEN_LOG

/**
 * Generates a fresh key: first the primes p and q (modulus n = p * q),
 * then the exponent pair e / d, which is derived from p and q.
 */
RSA_System::RSA_System()
	: e(0)
{
	this->create_RSA_mod(); // sets p and q
	this->create_e();       // sets e and d; needs p and q already chosen
}

/**
 * Rebuilds a key from its secret components.
 * @param p first prime of the modulus n = p * q
 * @param q second prime of the modulus n = p * q
 * @param d signing exponent used by get_signature()
 * e is left as 0; get_RSA_exp() derives it from d on request.
 * Members are initialized directly rather than default-constructed and then
 * assigned (in each initializer the name refers to the parameter).
 */
RSA_System::RSA_System(large_uint p,large_uint q, large_uint d)
	: p(p), q(q), d(d), e(0)
{
}

/**
 * Public modulus n = p * q (same value as get_n()).
 */
large_uint RSA_System::get_RSA_mod() const
{
	large_uint modulus = this->get_n();
	return modulus;
}

/**
 * Returns the public (verification) exponent e.
 * When e is 0 (as after the (p, q, d) constructor), it is derived as the
 * inverse of d modulo (p - 1)(q - 1).
 * The derived value is returned rather than stored: the previous version
 * cast away const to cache it, which is undefined behaviour when called on a
 * genuinely const RSA_System.
 */
large_uint RSA_System::get_RSA_exp() const
{
	if(e == 0) return d.inv(get_totient_n());
	return e;
}
void RSA_System::create_RSA_mod()
{
	do
	{
		do
		{
			// get prime p which length | p | = | n | / 2
			/** /
			large_uint p = large_uint::random();
			large_uint temp = (
									large_uint::one()<<RSA_MOD_SIZE/2 - 2  
								) 
								- large_uint::one();//mask contains RSA_MOD_SIZE - 2 bits '1'
			p&=temp;
			p<<=1;
			p+=large_uint::one();
			temp = large_uint::one()<<
					(
						RSA_MOD_SIZE / 2 - 1 //set first RSA_MOD_SIZE-1-th bit to 1
					);
			p|=temp;
			/**/
			p = (
					(
						(
							large_uint::random() &  //set RSA_MOD_SIZE-2 random bits
							(
								(
									large_uint::one()<<RSA_MOD_SIZE/2 - 2  
								) 
								- large_uint::one() //mask contains RSA_MOD_SIZE - 2 bits '1'
							)
							<<1
						) 
						+ large_uint::one() //set last zero-bit to 1, cause p should be odd
					) 
					| large_uint::one()<<
					(
						RSA_MOD_SIZE / 2 - 1 //set first RSA_MOD_SIZE/2-1-th bit to 1
					)
				);
			/**/
#ifdef SCREEN_LOG
			std::cout<<"p: "<<p<<"\n";
#endif
		}
		while(!probably_prime(p)); 
		do
		{
       		// get prime q which length | q | = | n | / 2
			q = (
					(
						(
							large_uint::random() &  //set RSA_MOD_SIZE-2 random bits
							(
								(
									large_uint::one()<<RSA_MOD_SIZE/2 - 2  
								) 
								- large_uint::one() //mask contains RSA_MOD_SIZE - 2 bits '1'
							)
							<<1
						) 
						+ large_uint::one() //set last zero-bit to 1, cause p should be odd
					) 
					| large_uint::one()<<
					(
						RSA_MOD_SIZE / 2 - 1 //set first RSA_MOD_SIZE/2-1-th bit to 1
					)
				);
#ifdef SCREEN_LOG
			std::cout<<"q: "<<q<<"\n";
#endif
		}
		while(!probably_prime(q));
#ifdef SCREEN_LOG
			std::cout<<"\n\n1) p: "<<p<<"\n";
			std::cout<<"2) q: "<<q<<"\n";
			std::cout<<"p size is : "<<p.size() <<" q size is : "<<q.size();
			std::cout<<" n size is : "<< get_n().size();
#endif
	}
	while(get_n().size()!=RSA_MOD_SIZE); // n should be with length | n | = RSA_MOD_SIZE
}

void RSA_System::create_e()
{
	//large_uint e;
	do
	{
		e = large_uint::random()%(get_n() - 3) + 2; // e = 2..p-2
	}
	while(large_uint::gcd(e,get_n())!=large_uint::one()); // GCD(e, n) should be 1

	d = e.inv(get_totient_n());
}
/**
 * Produces the RSA signature s = x^d mod n.
 *
 * @param x message representative; must be less than n.
 * @return  the signature s.
 * @throws  const char* when x >= n.
 *
 * With CHINESE defined, the exponentiation is split over p and q, and the
 * two halves are recombined with the Chinese Remainder Theorem:
 *   s = xp * q * (q^-1 mod p) + xq * p * (p^-1 mod q)   (mod n)
 * This satisfies s = xp (mod p) and s = xq (mod q). The previous form paired
 * each residue with its own prime, producing a wrong signature.
 */
large_uint RSA_System::get_signature(const large_uint &x) const
{
	if(x>=get_n()) throw "Signatured text should be less than RSA mod";
	large_uint s;
#ifdef CHINESE
	large_uint dp = d % (p - large_uint::one()), // dp = d mod (p - 1)
	           dq = d % (q - large_uint::one()), // dq = d mod (q - 1)
	           p1 = p.inv(q),                    // p1 = p^-1 mod q
	           q1 = q.inv(p);                    // q1 = q^-1 mod p
	large_uint xp = large_uint::pow_m((x % p), dp, p), // xp = (x mod p)^dp mod p
	           xq = large_uint::pow_m((x % q), dq, q); // xq = (x mod q)^dq mod q
	// Each residue is multiplied by the OTHER prime times that prime's inverse
	// modulo the residue's own prime. The term is then 0 modulo the other
	// prime and equal to the residue modulo its own prime.
	// NOTE(review): the products are passed unreduced to add_m, as before;
	// confirm that add_m reduces its operands mod n.
	s = large_uint::add_m(xp*q*q1, xq*p*p1, get_n());
#else
	s = large_uint::pow_m(x, d, get_n());
#endif /* CHINESE */
	return s;
}
/**
 * RSA modulus n = p * q.
 */
large_uint RSA_System::get_n() const
{
	large_uint modulus = p * q;
	return modulus;
}
/**
 * Euler's totient of the modulus: phi(n) = (p - 1) * (q - 1).
 */
large_uint RSA_System::get_totient_n() const
{
	large_uint p_minus_one = p - large_uint::one();
	large_uint q_minus_one = q - large_uint::one();
	return p_minus_one * q_minus_one;
}

/**
 * Stores the public key used by verify().
 * @param n RSA modulus
 * @param e public (verification) exponent
 */
RSA_Verifyer::RSA_Verifyer(large_uint n, large_uint e)
	: _n(n)
	, _e(e)
{
}
bool RSA_Verifyer::verify(const large_uint &x,const large_uint &s)
{
	return x==large_uint::pow_m(s, _e, _n);
}